{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.simplefilter('ignore')\n",
    "\n",
    "import logging\n",
    "\n",
    "import re\n",
    "\n",
    "import pandas as pd\n",
    "pd.set_option('max_rows', 500)\n",
    "pd.set_option('max_colwidth', 100)\n",
    "\n",
    "from tqdm import tqdm\n",
    "tqdm.pandas()\n",
    "\n",
    "from simpletransformers.ner import NERModel, NERArgs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Read the training CSV and preview its first five rows.\n",
    "train = pd.read_csv(filepath_or_buffer='raw_data/train.csv')\n",
    "train.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "clean_medicine = {\n",
    "    '甲霜锰锌': \"锰锌\",\n",
    "    '烯酰锰锌': \"锰锌\",\n",
    "    '霜脲锰锌': \"锰锌\",\n",
    "    '恶霜锰锌': \"锰锌\",\n",
    "    '春雷王铜': \"王铜\",\n",
    "    '阿维哒螨灵': \"哒螨灵\",\n",
    "    '苯甲丙环唑': \"丙环唑\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Turn each annotated sentence into per-character rows [sentence_id, char, tag].\n",
    "# Entities get B_/I_/E_ tags; every other character is tagged 'O'.\n",
    "entity_kinds = {'n_crop': 'crop', 'n_disease': 'disease', 'n_medicine': 'medicine'}\n",
    "\n",
    "\n",
    "def tag_word(word, kind):\n",
    "    \"\"\"Return (char, tag) pairs for one entity `word` of type `kind`.\"\"\"\n",
    "    if len(word) < 2:\n",
    "        # The tag set has no single-character entity tag. Keep the character in\n",
    "        # the sentence as 'O' instead of dropping it, so training sequences\n",
    "        # contain the same characters the model is given at prediction time.\n",
    "        print(f'word len < 2: {word}')\n",
    "        return [(c, 'O') for c in word]\n",
    "    middle = [(c, f'I_{kind}') for c in word[1:-1]]\n",
    "    return [(word[0], f'B_{kind}')] + middle + [(word[-1], f'E_{kind}')]\n",
    "\n",
    "\n",
    "train_data = list()\n",
    "\n",
    "for i, row in tqdm(train.iterrows(), total=len(train)):\n",
    "    id_ = i\n",
    "    # Strip digits and the listed punctuation/symbol characters before tagging.\n",
    "    text = re.sub(r\"[-\\d\\.%·\\+，。％一＋]\", \"\", row['text'])\n",
    "    text = text.replace('多   /w', '多/w')\n",
    "    for d in clean_medicine:\n",
    "        text = text.replace(d, clean_medicine[d])\n",
    "\n",
    "    # str.split() already removes all whitespace, so each item is 'word/label'.\n",
    "    for item in text.split():\n",
    "        parts = item.split('/')\n",
    "        if len(parts) != 2:\n",
    "            # Missing label or more than one '/': drop the '/label' suffixes\n",
    "            # and treat whatever characters remain as non-entity text.\n",
    "            item = re.sub(r'/[a-z]+', '', item)\n",
    "            pairs = [(c, 'O') for c in item]\n",
    "        elif parts[1] in entity_kinds:\n",
    "            pairs = tag_word(parts[0], entity_kinds[parts[1]])\n",
    "        else:\n",
    "            pairs = [(c, 'O') for c in parts[0]]\n",
    "        train_data.extend([id_, c, tag] for c, tag in pairs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per character: the sentence it belongs to, the character, and its tag.\n",
    "column_names = ['sentence_id', 'words', 'labels']\n",
    "train_data = pd.DataFrame(data=train_data, columns=column_names)\n",
    "\n",
    "train_data.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Frequency of each tag across all character rows.\n",
    "train_data['labels'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "labels = [\n",
    "    'B_crop',\n",
    "    'I_crop',\n",
    "    'E_crop',\n",
    "    'B_disease',\n",
    "    'I_disease',\n",
    "    'E_disease',\n",
    "    'B_medicine',\n",
    "    'I_medicine',\n",
    "    'E_medicine',\n",
    "    'O'\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hold out the rows whose sentence_id falls in the last 300 positions of `train`.\n",
    "eval_data = train_data[train_data['sentence_id'].ge(len(train) - 300)]\n",
    "eval_data.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the rows that were not held out for evaluation.\n",
    "train_data = train_data[train_data['sentence_id'].lt(len(train) - 300)]\n",
    "\n",
    "(train_data.shape, eval_data.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training configuration for the NER model.\n",
    "model_args = NERArgs()\n",
    "model_args.train_batch_size = 8\n",
    "model_args.num_train_epochs = 5\n",
    "model_args.fp16 = False\n",
    "model_args.evaluate_during_training = True\n",
    "# Fix the random seed so that a Restart & Run All reproduces the same training run.\n",
    "model_args.manual_seed = 42"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the NER model from the pretrained checkpoint with the custom tag set.\n",
    "model = NERModel(\n",
    "    model_type='bert',\n",
    "    model_name='hfl/chinese-bert-wwm-ext',\n",
    "    labels=labels,\n",
    "    args=model_args,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train on the training rows, passing the held-out rows as eval_data.\n",
    "model.train_model(train_data=train_data, eval_data=eval_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate on the held-out rows and display the returned result.\n",
    "result, model_outputs, preds_list = model.eval_model(eval_data=eval_data)\n",
    "result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the test CSV and preview its first five rows.\n",
    "test = pd.read_csv(filepath_or_buffer='raw_data/test.csv')\n",
    "test.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (rows, columns) of the test frame.\n",
    "test.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_entities(tagged, kind):\n",
    "    \"\"\"Collect entity strings of type `kind` from a list of (char, label) pairs.\n",
    "\n",
    "    A span opens at ``B_<kind>`` and closes at the next ``E_<kind>``.\n",
    "    ``I_<kind>`` and ``O`` characters in between are kept. Any other label\n",
    "    abandons the span, so an entity is never stitched together from\n",
    "    non-adjacent characters. A span that never reaches its ``E_`` tag is\n",
    "    discarded.\n",
    "    \"\"\"\n",
    "    begin, inside, end = f'B_{kind}', f'I_{kind}', f'E_{kind}'\n",
    "    found = list()\n",
    "    for start, (first_ch, first_lb) in enumerate(tagged):\n",
    "        if first_lb != begin:\n",
    "            continue\n",
    "        chars = [first_ch]\n",
    "        for ch, lb in tagged[start + 1:]:\n",
    "            if lb == end:\n",
    "                chars.append(ch)\n",
    "                found.append(\"\".join(chars))\n",
    "                break\n",
    "            if lb != inside and lb != 'O':\n",
    "                # Another entity (or a new B_ tag) begins here: give up on\n",
    "                # this span instead of skipping characters to a later E_ tag.\n",
    "                break\n",
    "            chars.append(ch)\n",
    "    return found\n",
    "\n",
    "\n",
    "test_data = list()\n",
    "\n",
    "for i, row in tqdm(test.iterrows(), total=len(test)):\n",
    "    id_ = i\n",
    "    # Same character stripping as for training, plus spaces. Because spaces are\n",
    "    # removed here, no separate fix-up for space-padded tokens is needed.\n",
    "    text = re.sub(r\"[-\\d\\.%·\\+，。％一＋ ]\", \"\", row['text'])\n",
    "    for d in clean_medicine:\n",
    "        text = text.replace(d, clean_medicine[d])\n",
    "\n",
    "    preds, _ = model.predict([text], split_on_space=False)\n",
    "\n",
    "    # preds[0] is iterated as a sequence of dicts; presumably each holds a single\n",
    "    # {char: label} pair for one character of `text` (confirm against the library).\n",
    "    tagged = [(ch, lb) for token in preds[0] for ch, lb in token.items()]\n",
    "\n",
    "    test_data.append([\n",
    "        id_,\n",
    "        extract_entities(tagged, 'crop'),\n",
    "        extract_entities(tagged, 'disease'),\n",
    "        extract_entities(tagged, 'medicine'),\n",
    "    ])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per test sentence: its id and the extracted entity lists.\n",
    "submission_columns = ['id', 'n_crop', 'n_disease', 'n_medicine']\n",
    "test_data = pd.DataFrame(data=test_data, columns=submission_columns)\n",
    "\n",
    "test_data.head(n=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the predictions to disk without the DataFrame index column.\n",
    "test_data.to_csv(path_or_buf='submission.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
